{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mhassan/anaconda3/envs/ml/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import sys\n",
    "sys.path.append(\"../utilities/\")\n",
    "from utility import FeatureGenerator\n",
    "import h5py\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "import numpy as np\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.externals import joblib\n",
    "import glob\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_class(label):\n",
    "    if label == \"non-inhibitor\":\n",
    "        return 0\n",
    "    else:\n",
    "        return 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract features\n",
    "def get_features(SMILES):\n",
    "    try:\n",
    "        feat_gen = FeatureGenerator(SMILES)\n",
    "        return feat_gen.toTPATF()\n",
    "    except:\n",
    "        return float('NaN')\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _build_dataset(_id, data_file):\n",
    "    \"\"\"Build the train/validation arrays for ``_id`` from the Excel sheets and cache them in ``data_file``.\"\"\"\n",
    "    train_df = pd.read_excel(\"TrainSet_\" + _id + \".xls\")\n",
    "    valid_df = pd.read_excel(\"ValidationSet_\" + _id + \".xls\")\n",
    "\n",
    "    train_df['class'] = train_df.Labels.apply(get_class)\n",
    "    valid_df['class'] = valid_df.Labels.apply(get_class)\n",
    "\n",
    "    train_df['features'] = train_df.SMILES.apply(get_features)\n",
    "    valid_df['features'] = valid_df.SMILES.apply(get_features)\n",
    "\n",
    "    # get_features returns NaN when featurization fails; dropna() removes those rows\n",
    "    # (and any row with a missing value in another column).\n",
    "    train_df = train_df.dropna()\n",
    "    valid_df = valid_df.dropna()\n",
    "\n",
    "    train_x, train_y = np.stack(train_df['features'].values).astype(np.float32), train_df['class'].values\n",
    "    valid_x, valid_y = np.stack(valid_df['features'].values).astype(np.float32), valid_df['class'].values\n",
    "    print(train_x.shape, train_y.shape)\n",
    "    print(valid_x.shape, valid_y.shape)\n",
    "\n",
    "    # The context manager closes the HDF5 file even if one of the writes fails.\n",
    "    with h5py.File(data_file, \"w\") as h5f:\n",
    "        h5f.create_dataset(\"train_x\", data=train_x)\n",
    "        h5f.create_dataset(\"train_y\", data=train_y)\n",
    "        h5f.create_dataset(\"valid_x\", data=valid_x)\n",
    "        h5f.create_dataset(\"valid_y\", data=valid_y)\n",
    "\n",
    "    return train_x, train_y, valid_x, valid_y\n",
    "\n",
    "\n",
    "def _load_dataset(data_file):\n",
    "    \"\"\"Read the four arrays written by ``_build_dataset`` back from ``data_file``.\"\"\"\n",
    "    with h5py.File(data_file, \"r\") as h5f:\n",
    "        return h5f[\"train_x\"][:], h5f[\"train_y\"][:], h5f[\"valid_x\"][:], h5f[\"valid_y\"][:]\n",
    "\n",
    "\n",
    "def doit(_id):\n",
    "    \"\"\"Prepare the cached dataset and the random-forest model for one dataset id.\n",
    "\n",
    "    Files used, all relative to the working directory:\n",
    "      * ``TrainSet_<_id>.xls`` / ``ValidationSet_<_id>.xls`` -- input spreadsheets\n",
    "        (columns ``Labels`` and ``SMILES`` are read)\n",
    "      * ``<_id>.h5``  -- cached feature/label arrays, written only if missing\n",
    "      * ``<_id>.mdl`` -- joblib dump of the selected model, written only if missing\n",
    "\n",
    "    When the model file is missing, a grid search over ``n_estimators``\n",
    "    (100..1000, step 100, 5-fold CV) is run and classification reports for the\n",
    "    training and validation sets are printed. Returns None.\n",
    "    \"\"\"\n",
    "    print(\"Generating models for \" + _id)\n",
    "    data_file = _id + \".h5\"\n",
    "    model_file = _id + \".mdl\"\n",
    "\n",
    "    if not os.path.isfile(data_file):\n",
    "        dataset = _build_dataset(_id, data_file)\n",
    "    else:\n",
    "        print(\"Existing data file found. Not data is being generated\")\n",
    "        dataset = None  # read from the cache below, only if a model must be trained\n",
    "\n",
    "    if os.path.isfile(model_file):\n",
    "        print(\"Existing model foudn for \" + _id + \". No model will be trained\")\n",
    "        return\n",
    "\n",
    "    # Previously the arrays were undefined here whenever the data file already\n",
    "    # existed but the model file did not, which raised NameError at fit time.\n",
    "    if dataset is None:\n",
    "        dataset = _load_dataset(data_file)\n",
    "    train_x, train_y, valid_x, valid_y = dataset\n",
    "\n",
    "    clf = RandomForestClassifier(class_weight=\"balanced\")\n",
    "    param_grid = {\"n_estimators\": list(range(100, 1001, 100))}\n",
    "    grid_clf = GridSearchCV(estimator=clf, cv=5, param_grid=param_grid, verbose=True, n_jobs=-1)\n",
    "    grid_clf.fit(train_x, train_y)\n",
    "\n",
    "    model = grid_clf.best_estimator_\n",
    "    print(model)\n",
    "\n",
    "    print(\"REPORT ON \" + _id)\n",
    "    print(\"TRAINING PERFORMANCE\")\n",
    "    y_pred = model.predict(train_x)\n",
    "    print(classification_report(y_true=train_y, y_pred=y_pred))\n",
    "\n",
    "    print(\"TEST PERFORMANCE\")\n",
    "    y_pred = model.predict(valid_x)\n",
    "    print(classification_report(y_true=valid_y, y_pred=y_pred))\n",
    "\n",
    "    joblib.dump(model, model_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['TrainSet_CYP2C9.xls',\n",
       " 'TrainSet_CYP2C19.xls',\n",
       " 'TrainSet_CYP2D6.xls',\n",
       " 'TrainSet_CYP3A4.xls',\n",
       " 'TrainSet_CYP1A2.xls']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_files = glob.glob(\"Train*\")\n",
    "training_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['CYP2C9', 'CYP2C19', 'CYP2D6', 'CYP3A4', 'CYP1A2']\n"
     ]
    }
   ],
   "source": [
    "# Drop the extension, then keep the part after the first underscore\n",
    "# (e.g. 'TrainSet_CYP2C9.xls' -> 'CYP2C9').\n",
    "_ids = [os.path.splitext(name)[0].split('_')[1] for name in training_files]\n",
    "print(_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating models for CYP2C9\n",
      "(12127, 2692) (12127,)\n",
      "(2579, 2692) (2579,)\n",
      "Fitting 5 folds for each of 10 candidates, totalling 50 fits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Done  41 out of  50 | elapsed:  2.1min remaining:   27.6s\n",
      "[Parallel(n_jobs=-1)]: Done  50 out of  50 | elapsed:  2.3min finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RandomForestClassifier(bootstrap=True, class_weight='balanced',\n",
      "            criterion='gini', max_depth=None, max_features='auto',\n",
      "            max_leaf_nodes=None, min_impurity_decrease=0.0,\n",
      "            min_impurity_split=None, min_samples_leaf=1,\n",
      "            min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
      "            n_estimators=300, n_jobs=1, oob_score=False, random_state=None,\n",
      "            verbose=0, warm_start=False)\n",
      "REPORT ON CYP2C9\n",
      "TRAINING PERFORMANCE\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       1.00      0.99      1.00      7759\n",
      "          1       0.99      1.00      0.99      4368\n",
      "\n",
      "avg / total       0.99      0.99      0.99     12127\n",
      "\n",
      "TEST PERFORMANCE\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       0.89      0.93      0.91      1970\n",
      "          1       0.75      0.64      0.69       609\n",
      "\n",
      "avg / total       0.86      0.86      0.86      2579\n",
      "\n",
      "Generating models for CYP2C19\n",
      "Existing data file found. Not data is being generated\n",
      "Existing model foudn for CYP2C19. No model will be trained\n",
      "Generating models for CYP2D6\n",
      "(11878, 2692) (11878,)\n",
      "(2860, 2692) (2860,)\n",
      "Fitting 5 folds for each of 10 candidates, totalling 50 fits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Done  41 out of  50 | elapsed:  2.0min remaining:   26.2s\n",
      "[Parallel(n_jobs=-1)]: Done  50 out of  50 | elapsed:  2.2min finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RandomForestClassifier(bootstrap=True, class_weight='balanced',\n",
      "            criterion='gini', max_depth=None, max_features='auto',\n",
      "            max_leaf_nodes=None, min_impurity_decrease=0.0,\n",
      "            min_impurity_split=None, min_samples_leaf=1,\n",
      "            min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
      "            n_estimators=1000, n_jobs=1, oob_score=False,\n",
      "            random_state=None, verbose=0, warm_start=False)\n",
      "REPORT ON CYP2D6\n",
      "TRAINING PERFORMANCE\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       1.00      0.99      1.00      9362\n",
      "          1       0.97      1.00      0.98      2516\n",
      "\n",
      "avg / total       0.99      0.99      0.99     11878\n",
      "\n",
      "TEST PERFORMANCE\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       0.89      0.97      0.93      2316\n",
      "          1       0.79      0.47      0.59       544\n",
      "\n",
      "avg / total       0.87      0.88      0.86      2860\n",
      "\n",
      "Generating models for CYP3A4\n",
      "(11533, 2692) (11533,)\n",
      "(7025, 2692) (7025,)\n",
      "Fitting 5 folds for each of 10 candidates, totalling 50 fits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Done  41 out of  50 | elapsed:  2.1min remaining:   27.6s\n",
      "[Parallel(n_jobs=-1)]: Done  50 out of  50 | elapsed:  2.5min finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RandomForestClassifier(bootstrap=True, class_weight='balanced',\n",
      "            criterion='gini', max_depth=None, max_features='auto',\n",
      "            max_leaf_nodes=None, min_impurity_decrease=0.0,\n",
      "            min_impurity_split=None, min_samples_leaf=1,\n",
      "            min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
      "            n_estimators=1000, n_jobs=1, oob_score=False,\n",
      "            random_state=None, verbose=0, warm_start=False)\n",
      "REPORT ON CYP3A4\n",
      "TRAINING PERFORMANCE\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       1.00      0.99      1.00      6898\n",
      "          1       0.99      1.00      0.99      4635\n",
      "\n",
      "avg / total       0.99      0.99      0.99     11533\n",
      "\n",
      "TEST PERFORMANCE\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       0.84      0.84      0.84      4955\n",
      "          1       0.62      0.63      0.62      2070\n",
      "\n",
      "avg / total       0.78      0.77      0.77      7025\n",
      "\n",
      "Generating models for CYP1A2\n",
      "Existing data file found. Not data is being generated\n",
      "Existing model foudn for CYP1A2. No model will be trained\n"
     ]
    }
   ],
   "source": [
    "# Build (or reuse) the cached dataset and model for each id in turn.\n",
    "for dataset_id in _ids:\n",
    "    doit(dataset_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
